function sensor_data = kspaceFirstOrder3D(p0, kgrid, c, rho, t_array, sensor_mask, varargin)
%KSPACEFIRSTORDER3D     3D time-domain simulation of wave propagation.
%
% DESCRIPTION:
%       kspaceFirstOrder3D simulates the time-domain propagation of linear
%       compressional waves through a three-dimensional homogeneous or
%       heterogeneous acoustic medium defined by c and rho given the
%       initial pressure distribution p0. The size and discretisation of
%       the acoustic domain are defined by the k-space grid structure
%       kgrid. At each time step the pressure at the positions defined by
%       sensor_mask is recorded and stored.
%
%       The computation is based on a first-order k-space model which
%       allows a heterogeneous sound speed and density. An absorbing
%       boundary condition (in this case a perfectly matched layer) is
%       implemented to prevent waves that leave one side of the domain
%       being reintroduced from the opposite side (a consequence of using
%       the FFT to compute the spatial derivatives in the wave-equation).
%       This allows infinite domain simulations to be computed using small
%       computational grids.
%
%       For a homogeneous medium the formulation is exact and the time
%       steps are only limited by the effectiveness of the perfectly
%       matched layer (PML). For a heterogeneous medium, the solution
%       represents a leap-frog pseudospectral method with a Laplacian
%       correction that improves the accuracy of computing the temporal
%       derivatives. This allows larger time-steps to be taken without
%       instability compared to conventional pseudospectral time-domain
%       methods. The computational grids are staggered both spatially and
%       temporally.
%
%       The pressure is returned as an array of time-series at the sensor
%       locations defined by sensor_mask. This can be given either as a
%       binary grid (i.e., a matrix of 1's and 0's the same size as p0)
%       representing the pixels within the computational grid that will
%       collect the data, or as a series of arbitrary Cartesian coordinates
%       within the grid at which the pressure values are calculated at each
%       time-step via interpolation. The Cartesian points must be given as
%       a 3 by N matrix corresponding to the x, y, and z positions,
%       respectively.
%
%       If sensor_mask is given as a set of Cartesian coordinates, the
%       computed sensor_data is returned in the same order. If sensor_mask
%       is given as a binary grid, sensor_data is returned using MATLAB's
%       standard column-wise linear matrix index ordering. In both cases,
%       the recorded data is indexed as sensor_data(sensor position, time).
%       For a binary sensor mask, the pressure values at a particular time
%       can be restored to the sensor positions within the computation grid
%       using unmaskSensorData.
%
%       The code may also be used for time reversal image reconstruction by
%       setting the optional input 'TimeRev' to true. This enforces the
%       pressure given by p0 as a time varying Dirichlet boundary condition
%       over the sensor mask. In this mode, the input pressure p0 must be
%       indexed as p0(sensor position, time). If sensor_mask is given as a
%       set of Cartesian coordinates then p0 must be given in the same
%       order. An equivalent binary sensor mask (computed using nearest
%       neighbour interpolation) is then used to place the pressure values
%       into the computational grid at each time-step. If sensor_mask is
%       given as a binary grid of sensor points then p0 must be given as an
%       array ordered using MATLAB's standard column-wise linear matrix
%       indexing.
%
%       To run the inverse crime, the sensor_data returned when 'TimeRev'
%       is set to false can be used without modification as p0 when
%       'TimeRev' is set to true.
%
% USAGE:
%       sensor_data = kspaceFirstOrder3D(p0, kgrid, c, rho, t_array, sensor_mask)
%       sensor_data = kspaceFirstOrder3D(p0, kgrid, c, rho, t_array, sensor_mask, ...)
%
% INPUTS:
%       p0          - map of the initial pressure within the medium over
%                     the discretisation given by kgrid (or the time
%                     varying pressure across the sensor mask if 'TimeRev'
%                     is set to true)
%       kgrid       - kspace grid structure returned by makeGrid.m
%       c / rho     - maps of the sound speed and density over the medium
%                     discretisation given by kgrid. For homogeneous media,
%                     c and rho may be given as a single value.
%       t_array     - evenly spaced array of time values [s] (t_array can
%                     alternatively be set to 'auto' to automatically
%                     generate the array using makeTime; 'auto' is not
%                     accepted in time reversal mode, where t_array must
%                     be given explicitly)
%       sensor_mask - binary grid or a set of Cartesian points where the
%                     pressure is recorded at each time step
%
% OPTIONAL INPUTS:
%       Optional 'string', value pairs that may be used to modify the
%       default computational settings.
%
%       'AdaptThresh' - Adaptive boundary condition threshold used when
%                     'TimeRev' is set to 3 (default = 0.005).
%       'CartInterp'- Interpolation mode used to extract the pressure when
%                     a Cartesian sensor mask is given. If set to 'nearest'
%                     and more than one Cartesian point maps to the same
%                     pixel, duplicated data points are discarded and
%                     sensor_data will be returned with fewer points than
%                     specified by sensor_mask. Any value other than
%                     'nearest' is currently rejected with an error
%                     (default = 'nearest').
%       'DataCast'  - String input of the data type that variables are cast
%                     to before computation. For example, setting to
%                     'single' will speed up the computation time (due to
%                     the improved efficiency of fftn and ifftn for this
%                     data type) at the expense of a loss in precision.
%                     The returned sensor_data is always cast back to
%                     double (default = 'off').
%       'PlotFreq'  - The number of iterations which must pass before the
%                     simulation plot is updated (default = 10).
%       'PlotLayout'- Boolean controlling whether three plots are
%                     produced of the initial simulation layout (initial
%                     pressure, sound speed, density). The layout plot is
%                     always switched off in time reversal mode
%                     (default = false).
%       'PlotScale' - [min, max] values used to control the scaling for
%                     imagesc (visualisation) (default = [-1 1]).
%       'PlotSim'   - Boolean controlling whether the simulation iterations
%                     are progressively plotted (default = true).
%       'PMLAlpha'  - attenuation in Nepers per m of the absorption within
%                     the perfectly matched layer (default = 4).
%       'PMLInside' - Boolean controlling whether the perfectly matched
%                     layer is inside or outside the grid. Currently only
%                     'PMLInside' = true is supported (default = true).
%       'PMLSize'   - size of the perfectly matched layer in pixels. By
%                     default, the PML is added evenly to all sides of the
%                     grid, however, both PMLSize and PMLAlpha can be given
%                     as 3 element arrays to specify the x, y, and z
%                     properties, respectively. To remove the PML, set the
%                     appropriate PMLAlpha to zero rather than forcing the
%                     PML to be of zero size (default = 20).
%       'Smooth'    - Boolean controlling whether the p0, c, and rho
%                     distributions are smoothed. Smooth can also be given
%                     as a 3 element array to control the smoothing of p0,
%                     c, and rho, respectively. If a 2 element array is
%                     given, the first element is used for p0 and the
%                     second for both c and rho. p0 is only smoothed in
%                     forward mode, and c and rho only when they are given
%                     as 3D maps (default = true).
%       'TimeRev'   - Boolean controlling whether the code is used in time
%                     reversal mode. If set to true (or 1), the time
%                     reversal is computed by enforcing the pressure values
%                     given by p0 over the sensor surface at each time step
%                     (conventional time reversal). If set to 2, the time
%                     reversal is computed by introducing the pressure
%                     values given by p0 over the sensor surface as a
%                     source term. If set to 3, the time reversal is
%                     computed by using an adaptive boundary condition with
%                     the threshold set by 'AdaptThresh' (default = false).
%
% OUTPUTS:
%       sensor_data - array of pressure time-series recorded at the sensor
%                     positions given by sensor_mask. In time reversal
%                     mode, sensor_data instead contains the final pressure
%                     field over the computational grid.
%
% ABOUT:
%       author      - Bradley Treeby and Ben Cox
%       date        - 7th April 2009
%       last update - 17th July 2009
%
% This function is part of the k-Wave Toolbox (http://www.k-wave.org)
%
% See also fftn, ifftn, imagesc, kspaceFirstOrder1D, kspaceFirstOrder2D,
% makeGrid, makeTime, smooth, unmaskSensorData

% KNOWN ISSUES:
%       - status bar does not always update correctly in time reversal mode
%         for earlier versions of matlab
%       - progress bar does not close using 'close all' if ctrl+c forced
%         break is used to exit the simulation

% start the timer (read back after the precomputation stage)
tic;

% =========================================================================
% DEFINE LITERALS
% =========================================================================

% general
NUM_REQ_INPUT_VARIABLES = 6;
COLOR_MAP = getColorMap();

% input defaults
ADAPTIVE_BC_THRESHOLD_DEF = 0.005;
CARTESIAN_INTERP_DEF = 'nearest';
DATA_CAST_DEF = 'off';
PLOT_SIM_DEF = true;
PLOT_FREQ_DEF = 10;
PLOT_SCALE_DEF = [-1 1];
PLOT_LAYOUT_DEF = false;
PML_ALPHA_DEF = 4;
PML_INSIDE_DEF = true;
PML_SIZE_DEF = 20;
SMOOTH_DEF = true;
TIME_REV_DEF = false;

% =========================================================================
% EXTRACT OPTIONAL INPUTS
% =========================================================================

% assign default input parameters
adaptive_bc_threshold = ADAPTIVE_BC_THRESHOLD_DEF;
cartesian_interp = CARTESIAN_INTERP_DEF;
data_cast = DATA_CAST_DEF;
plot_freq = PLOT_FREQ_DEF;
plot_layout = PLOT_LAYOUT_DEF;
plot_scale = PLOT_SCALE_DEF;
plot_sim = PLOT_SIM_DEF;
PML_x_size = PML_SIZE_DEF;
PML_y_size = PML_SIZE_DEF;
PML_z_size = PML_SIZE_DEF;
PML_x_alpha = PML_ALPHA_DEF;
PML_y_alpha = PML_ALPHA_DEF;
PML_z_alpha = PML_ALPHA_DEF;
PML_inside = PML_INSIDE_DEF;
smooth_p0 = SMOOTH_DEF;
smooth_c = SMOOTH_DEF;
smooth_rho = SMOOTH_DEF;
time_rev = TIME_REV_DEF;

% replace with user defined values if provided and check inputs
% (the number of required inputs is even, so an odd nargin means that an
% optional parameter name was given without a matching value)
if nargin < NUM_REQ_INPUT_VARIABLES
    error('Not enough input parameters');
elseif rem(nargin, 2)
    error('Optional input parameters must be given as param, value pairs');    
elseif ~isempty(varargin)
    % step through the inputs in (name, value) pairs
    for input_index = 1:2:length(varargin)
        switch varargin{input_index}
            case 'AdaptThresh'
                adaptive_bc_threshold = varargin{input_index + 1};
            case 'CartInterp'
                cartesian_interp = varargin{input_index + 1}; 
                if ~strcmp(cartesian_interp, 'nearest')
                    error('Optional input CartInterp currently only supports nearest neighbour interpolation');
                end                
            case 'DataCast'
                data_cast = varargin{input_index + 1};     
            case 'PlotFreq'
                plot_freq = varargin{input_index + 1};   
                if ~(numel(plot_freq) == 1 && isnumeric(plot_freq))
                    error('Optional input PlotFreq must be a single numerical value');
                end                
            case 'PlotLayout'
                plot_layout = varargin{input_index + 1};
                if ~islogical(plot_layout)
                    error('Optional input PlotLayout must be boolean');
                end                
            case 'PlotScale'
                plot_scale = varargin{input_index + 1}; 
                if ~(numel(plot_scale) == 2 && isnumeric(plot_scale))
                    error('Optional input PlotScale must be a 2 element numerical array');
                end                
            case 'PlotSim'
                plot_sim = varargin{input_index + 1};   
                if ~islogical(plot_sim)
                    error('Optional input PlotSim must be boolean');
                end                
            case 'PMLAlpha'
                if length(varargin{input_index + 1}) > 3
                    error('Optional input PMLAlpha must be a 1 or 3 element numerical array');
                end                
                % the index ceil((end + 1)/2) selects element 1 of a scalar
                % input and element 2 of a 2 or 3 element input, so a
                % scalar is applied to all three directions
                PML_x_alpha = varargin{input_index + 1}(1);
                PML_y_alpha = varargin{input_index + 1}(ceil((end + 1)/2));
                PML_z_alpha = varargin{input_index + 1}(end);
            case 'PMLInside'
                PML_inside = varargin{input_index + 1};
                if ~PML_inside
                    error('Optional input PMLInside currently only supports an interior PML');
                end
            case 'PMLSize'
                if length(varargin{input_index + 1}) > 3
                    error('Optional input PMLSize must be a 1 or 3 element numerical array');
                end
                % same scalar / 3 element indexing as used for PMLAlpha
                PML_x_size = varargin{input_index + 1}(1);
                PML_y_size = varargin{input_index + 1}(ceil((end + 1)/2));
                PML_z_size = varargin{input_index + 1}(end);                
            case 'Smooth'
                if length(varargin{input_index + 1}) > 3
                    error('Optional input Smooth must be a 1, 2 or 3 element boolean array');
                end
                % for a 2 element input the second element is used for both
                % c and rho
                smooth_p0 = varargin{input_index + 1}(1);
                smooth_c = varargin{input_index + 1}(ceil((end + 1)/2));
                smooth_rho = varargin{input_index + 1}(end);
            case 'TimeRev'
                time_rev = varargin{input_index + 1};
            otherwise
                error('Unknown optional input');
        end
    end
end

% cleanup unused variables
clear *_DEF NUM_REQ_INPUT_VARIABLES;

% switch off layout plot in time reversal mode
plot_layout = plot_layout && ~time_rev;

% update command line status
if ~time_rev
    disp('Running k-space simulation...'); 
else
    disp('Running k-space time reversal...');
end

% check scaling input: warn if the maximum of p0 differs from the upper
% plot limit by more than a factor of 10 in either direction
if ~time_rev && plot_sim
    if max(p0(:)) > 10*plot_scale(2) || 10*max(p0(:)) < plot_scale(2)
        disp('  WARNING: visualisation plot scale may not be optimal for given p0');
    end
end

% =========================================================================
% SETUP TIME VARIABLE
% =========================================================================

% automatically create a suitable time array if required
if strcmp(t_array, 'auto')
    if ~time_rev
        % create the time array
        [t_array dt] = makeTime(kgrid, max(c(:)));   
    else
        % throw error requesting for t_array
        error('t_array must be given explicitly in time reversal mode');
    end
else
    % extract the time step from the input data (t_array is taken to be
    % evenly spaced, so only the first two entries are used)
    dt = t_array(2) - t_array(1); 
    
    % check for stability; the check is only made when numDim reports 3
    % for c or rho (presumably meaning a 3D heterogeneous map - see numDim)
    % NOTE(review): the limit is based on the largest grid spacing - confirm
    % whether the smallest spacing was intended
    if (numDim(c) == 3 || numDim(rho) == 3) && (dt > 0.5*max([kgrid.dz, kgrid.dx, kgrid.dy])/max(c(:)))
        disp('  WARNING: time step may be too large for a stable simulation');
    end
end

% setup the time index variable (the time loop runs forwards through
% t_array in forward mode and backwards in time reversal mode)
if ~time_rev
    index_start = 1;
    index_step = 1;
    index_end = length(t_array);
else
    index_start = length(t_array);
    index_step = -1;
    index_end = 1;
end

% =========================================================================
% CHECK SENSOR MASK INPUT
% =========================================================================

% switch off Cartesian reorder flag
reorder_data = false;

% check if sensor mask is a binary grid or a set of interpolation points
if numDim(sensor_mask) == 3
    
    % check the grid is binary: the sum of all entries must equal the
    % number of non-zero entries
    % NOTE(review): non-binary values whose sum happens to equal the number
    % of non-zero entries would also pass this test
    if sum(sensor_mask(:)) ~= numel(sensor_mask) - sum(sensor_mask(:) == 0)
        error('sensor_mask must be a binary grid (numeric values must be 0 or 1)');
    end
    
else
      
    % compute an equivalent sensor mask using nearest neighbour
    % interpolation
    [sensor_mask, order_index, reorder_index] = cart2grid(kgrid, sensor_mask);
   
    if ~time_rev && strcmp(cartesian_interp, 'nearest')
        % use the interpolated binary sensor mask but switch on Cartesian
        % reorder flag 
        reorder_data = true;
    else
        % other interpolation methods currently not supported
    end
        
    % reorder the p0 input data in the order of the binary sensor_mask 
    if time_rev
        
        % append the reordering data as an extra column of p0
        new_col_pos = length(p0(1,:)) + 1;
        p0(:, new_col_pos) = order_index;

        % reorder p0 based on the order_index
        p0 = sortrows(p0, new_col_pos);
        
        % remove the reordering data
        p0 = p0(:, 1:new_col_pos - 1);
        
    end
end

% =========================================================================
% UPDATE COMMAND LINE STATUS
% =========================================================================

disp(['  dt: ' scaleSI(dt) 's, t_end: ' scaleSI(t_array(end)) 's, time steps: ' num2str(length(t_array))]);
[x_sc scale prefix] = scaleSI(min([kgrid.Nx*kgrid.dx, kgrid.Ny*kgrid.dy, kgrid.Nz*kgrid.dz]));
disp(['  input grid size: ' num2str(kgrid.Nx) ' by ' num2str(kgrid.Ny) ' by ' num2str(kgrid.Nz) ' pixels (' num2str(kgrid.x_size*scale) ' by ' num2str(kgrid.y_size*scale) ' by ' num2str(kgrid.z_size*scale) prefix 'm)']); 

% =========================================================================
% PREPARE COMPUTATIONAL GRIDS
% =========================================================================

% smooth p0 distribution if required and then restore maximum magnitude
% (forward mode only; in time reversal mode p0 holds sensor time-series)
if smooth_p0 && ~time_rev
    disp('  smoothing p0 distribution...');      
    p0 = smooth(p0, kgrid, true);
end

% expand grid if the PML is set to be outside the input grid
if ~PML_inside
    % 'PML_inside = false not currently supported
else
    % create indexes to place the source input exactly into the simulation
    % grid (with an interior PML these span the complete grid)
    x1 = 1;
    x2 = kgrid.Nx;
    y1 = 1;
    y2 = kgrid.Ny;
    z1 = 1;
    z2 = kgrid.Nz;
end

% select the reference sound speed as the maximum of the sound speed map
% (taken before c is smoothed)
c0 = max(c(:));

% smooth c distribution if required
if smooth_c && numDim(c) == 3
    disp('  smoothing c distribution...');      
    c = smooth(c, kgrid);
end
    
% smooth rho distribution if required
if smooth_rho && numDim(rho) == 3
    disp('  smoothing rho distribution...');      
    rho = smooth(rho, kgrid);
end

% =========================================================================
% PREPARE STAGGERED DENSITY AND PML GRIDS
% =========================================================================

% create the staggered grids (shifted by half a grid spacing)
x_sg = kgrid.x + kgrid.dx/2;
y_sg = kgrid.y + kgrid.dy/2;
z_sg = kgrid.z + kgrid.dz/2;

% interpolate the values of the density at the staggered grid locations
% where r1 = (x + dx/2, y, z), r2 = (x, y + dy/2, z), r3 = (x, y, z + dz/2)
% values outside of the interpolation range are replaced with their
% original values 
if numDim(rho) == 3
    % rho is heterogeneous (coordinates are passed in z, x, y order to
    % match the dimension ordering used for the field matrices)
    rho_r1 = interpn(kgrid.z, kgrid.x, kgrid.y,  rho, kgrid.z, x_sg, kgrid.y, '*linear');
    rho_r1(isnan(rho_r1)) = rho(isnan(rho_r1));
    rho_r2 = interpn(kgrid.z, kgrid.x, kgrid.y, rho, kgrid.z, kgrid.x, y_sg, '*linear');
    rho_r2(isnan(rho_r2)) = rho(isnan(rho_r2));
    rho_r3 = interpn(kgrid.z, kgrid.x, kgrid.y, rho, z_sg, kgrid.x, kgrid.y,  '*linear');
    rho_r3(isnan(rho_r3)) = rho(isnan(rho_r3));
    
else
    % rho is homogeneous
    rho_r1 = rho;
    rho_r2 = rho;
    rho_r3 = rho;
end

% define the location of the perfectly matched layer within the grid
% (inner edges of the layer at each side of the domain)
x0_min = kgrid.x(1) + PML_x_size*kgrid.dx;
x0_max = kgrid.x(end) - PML_x_size*kgrid.dx;
y0_min = kgrid.y(1) + PML_y_size*kgrid.dy;
y0_max = kgrid.y(end) - PML_y_size*kgrid.dy;
z0_min = kgrid.z(1) + PML_z_size*kgrid.dz;
z0_max = kgrid.z(end) - PML_z_size*kgrid.dz;

% set the PML attenuation over the pressure (regular) grid; the profile is
% a fourth power of the normalised distance into the layer, scaled by
% alpha*c0/dx, and is masked to zero outside of the layer
ax = PML_x_alpha*(c0/kgrid.dx)*((kgrid.x - x0_max)./(kgrid.x(end) - x0_max)).^4.*(kgrid.x >= x0_max)...
    + PML_x_alpha*(c0/kgrid.dx)*((kgrid.x - x0_min)./(kgrid.x(1) - x0_min)).^4.*(kgrid.x <= x0_min);

ay = PML_y_alpha*(c0/kgrid.dy)*((kgrid.y - y0_max)./(kgrid.y(end) - y0_max)).^4.*(kgrid.y >= y0_max)...
    + PML_y_alpha*(c0/kgrid.dy)*((kgrid.y - y0_min)./(kgrid.y(1) - y0_min)).^4.*(kgrid.y <= y0_min);

az = PML_z_alpha*(c0/kgrid.dz)*((kgrid.z - z0_max)./(kgrid.z(end) - z0_max)).^4.*(kgrid.z >= z0_max)...
    + PML_z_alpha*(c0/kgrid.dz)*((kgrid.z - z0_min)./(kgrid.z(1) - z0_min)).^4.*(kgrid.z <= z0_min);

% set the PML attenuation over the velocity (staggered) grid
ax_sg = PML_x_alpha*(c0/kgrid.dx)*((x_sg - x0_max)./(kgrid.x(end) - x0_max)).^4.*(x_sg >= x0_max)...
    + PML_x_alpha*(c0/kgrid.dx)*((x_sg - x0_min)./(kgrid.x(1) - x0_min)).^4.*(x_sg <= x0_min);

ay_sg = PML_y_alpha*(c0/kgrid.dy)*((y_sg - y0_max)./(kgrid.y(end) - y0_max)).^4.*(y_sg >= y0_max)...
    + PML_y_alpha*(c0/kgrid.dy)*((y_sg - y0_min)./(kgrid.y(1) - y0_min)).^4.*(y_sg <= y0_min);

az_sg = PML_z_alpha*(c0/kgrid.dz)*((z_sg - z0_max)./(kgrid.z(end) - z0_max)).^4.*(z_sg >= z0_max)...
    + PML_z_alpha*(c0/kgrid.dz)*((z_sg - z0_min)./(kgrid.z(1) - z0_min)).^4.*(z_sg <= z0_min);

% precompute shift and absorbing boundary condition operators
% (the *_shift operators are half grid spacing phase shifts in k-space;
% the velocity operators abc_* use dt/2 because they are applied twice per
% time step, the pressure operators abc_*_alt use dt and are applied once)
x_shift = exp(1i*kgrid.kx*kgrid.dx/2);
y_shift = exp(1i*kgrid.ky*kgrid.dy/2);
z_shift = exp(1i*kgrid.kz*kgrid.dz/2);
x_shift_min = exp(-1i*kgrid.kx*kgrid.dx/2);
y_shift_min = exp(-1i*kgrid.ky*kgrid.dy/2);
z_shift_min = exp(-1i*kgrid.kz*kgrid.dz/2);
abc_x = exp(-ax_sg*dt/2);
abc_y = exp(-ay_sg*dt/2);
abc_z = exp(-az_sg*dt/2);
abc_x_alt = exp(-ax*dt);
abc_y_alt = exp(-ay*dt);
abc_z_alt = exp(-az*dt);

% cleanup unused variables
clear ax* ay* az* x0_min x0_max x_sg y0_min y0_max y_sg z0_min z0_max z_sg PML*;

% =========================================================================
% PREPARE DATA MASKS AND STORAGE VARIABLES
% =========================================================================

% create mask indices (linear indices of the non-zero sensor mask entries)
sensor_mask_ind  = find(sensor_mask ~= 0);

% create storage and scaling variables (conventional time reversal,
% time_rev = 1, needs no additional variables)
switch time_rev
    case 0
        
        % preallocate storage variables, indexed as (sensor position, time)
        sensor_data = zeros(sum(sensor_mask(:)), length(t_array));
        
    case 2
        
        % extract the sound speed at sensor location
        if length(c) > 1
            c_sens = c(sensor_mask ~= 0);
        else
            c_sens = c;
        end

        % assign source term scale values       
        time_rev_x_scale = (2*c_sens*dt)/(3*kgrid.dx);
        time_rev_y_scale = (2*c_sens*dt)/(3*kgrid.dy);
        time_rev_z_scale = (2*c_sens*dt)/(3*kgrid.dz);
        
        % cleanup unused variables
        clear c_sens;
        
    case 3
        
        % preallocate the sensor variable (one entry per grid point)
        p_sensor = zeros(1, kgrid.Nx*kgrid.Ny*kgrid.Nz);

        % precompute the threshold index: for each time step (column of
        % p0), count the sensor values with magnitude at or above the
        % threshold (sum operates down the columns when p0 has more than
        % one row)
        p_sensor_threshold_index  = abs(p0) >= adaptive_bc_threshold;
        p_sensor_threshold_index  = sum(p_sensor_threshold_index);
        
        % set pressure values below threshold to zero
        p0(abs(p0) < adaptive_bc_threshold) = 0;        

end

% =========================================================================
% SET INITIAL CONDITIONS
% =========================================================================

% set the initial pressure and velocity within the domain to zero
% (field matrices are indexed as (z, x, y))
ux_r1 = zeros(kgrid.Nz, kgrid.Nx, kgrid.Ny);
uy_r2 = zeros(kgrid.Nz, kgrid.Nx, kgrid.Ny);
uz_r3 = zeros(kgrid.Nz, kgrid.Nx, kgrid.Ny);
px = zeros(kgrid.Nz, kgrid.Nx, kgrid.Ny);
py = zeros(kgrid.Nz, kgrid.Nx, kgrid.Ny);
pz = zeros(kgrid.Nz, kgrid.Nx, kgrid.Ny);
p = zeros(kgrid.Nz, kgrid.Nx, kgrid.Ny);

% define the modified first order k-space derivative operators (spectral
% derivative 1i*k multiplied by the sinc(c0*dt*k/2) k-space correction)
ddx_k = 1i*kgrid.kx.*sinc(c0*dt*kgrid.k/2);
ddy_k = 1i*kgrid.ky.*sinc(c0*dt*kgrid.k/2);
ddz_k = 1i*kgrid.kz.*sinc(c0*dt*kgrid.k/2);

% set initial pressure if in forward mode (p0 is split equally between
% the three pressure components so that px + py + pz = p0)
if ~time_rev
    px(z1:z2, x1:x2, y1:y2) = p0/3;
    py(z1:z2, x1:x2, y1:y2) = p0/3;
    pz(z1:z2, x1:x2, y1:y2) = p0/3;
end

% pre-shift variables used as frequency domain multipliers so that they
% can be multiplied directly with the output of fftn in the time loop
ddx_k = ifftshift(ddx_k);
ddy_k = ifftshift(ddy_k);
ddz_k = ifftshift(ddz_k);
x_shift = ifftshift(x_shift);
y_shift = ifftshift(y_shift);
z_shift = ifftshift(z_shift);
x_shift_min = ifftshift(x_shift_min);
y_shift_min = ifftshift(y_shift_min);
z_shift_min = ifftshift(z_shift_min);

% =========================================================================
% PREPARE VISUALISATIONS
% =========================================================================

% pre-compute suitable axes scaling factor
if plot_layout || plot_sim
    [x_sc scale prefix] = scaleSI(max([kgrid.x(1,:), kgrid.z(:,1)'])); 
end

% plot the simulation layout
if plot_layout
    
    % initial pressure
    figure;
    planeplot(kgrid, p0, 'Initial Pressure: ', plot_scale, scale, prefix, COLOR_MAP);
    
    % plot c if heterogeneous (colour limits from half the minimum to
    % twice the maximum value)
    if numDim(c) == 3
        c_plot_scale = [0.5*min(c(:)), max(c(:))/0.5];
        figure;
        planeplot(kgrid, c, 'c: ', c_plot_scale, scale, prefix, COLOR_MAP);
    end

    % plot rho if heterogeneous (same colour limit rule as for c)
    if numDim(rho) == 3    
        rho_plot_scale = [0.5*min(rho(:)), max(rho(:))/0.5];
        figure;
        planeplot(kgrid, rho, 'rho: ', rho_plot_scale, scale, prefix, COLOR_MAP);
    end

end

% initialise the figures used for animation (both handles are closed in
% the clean up section at the end of the function)
if plot_sim
    img = figure;
    if ~time_rev
        pbar = waitbar(0, 'Computing Pressure Field');
    else
        pbar = waitbar(0, 'Computing Time Reversed Field');
    end
end 

% =========================================================================
% DATA CASTING
% =========================================================================

% NOTE(review): data_cast is concatenated into strings passed to eval, so
% it must be the name of a valid casting function (e.g. 'single'); any
% other string would be executed as code
if ~strcmp(data_cast, 'off');
    
    % update command line status
    disp(['  casting variables to ' data_cast ' type...']);

    % cast computation variables to data_cast type
    eval(['x_shift = ' data_cast '(x_shift);']);
    eval(['y_shift = ' data_cast '(y_shift);']);
    eval(['z_shift = ' data_cast '(z_shift);']);
    eval(['x_shift_min = ' data_cast '(x_shift_min);']);
    eval(['y_shift_min = ' data_cast '(y_shift_min);']);    
    eval(['z_shift_min = ' data_cast '(z_shift_min);']);
    eval(['ux_r1 = ' data_cast '(ux_r1);']);
    eval(['uy_r2 = ' data_cast '(uy_r2);']);    
    eval(['uz_r3 = ' data_cast '(uz_r3);']);
    eval(['px = ' data_cast '(px);']);
    eval(['py = ' data_cast '(py);']);
    eval(['pz = ' data_cast '(pz);']);
    eval(['ddx_k = ' data_cast '(ddx_k);']);
    eval(['ddy_k = ' data_cast '(ddy_k);']);    
    eval(['ddz_k = ' data_cast '(ddz_k);']);
    eval(['abc_x = ' data_cast '(abc_x);']);
    eval(['abc_y = ' data_cast '(abc_y);']);
    eval(['abc_z = ' data_cast '(abc_z);']);
    eval(['abc_x_alt = ' data_cast '(abc_x_alt);']);
    eval(['abc_y_alt = ' data_cast '(abc_y_alt);']);    
    eval(['abc_z_alt = ' data_cast '(abc_z_alt);']);
    eval(['dt = ' data_cast '(dt);']);
    eval(['rho = ' data_cast '(rho);']);
    eval(['rho_r1 = ' data_cast '(rho_r1);']);
    eval(['rho_r2 = ' data_cast '(rho_r2);']);
    eval(['rho_r3 = ' data_cast '(rho_r3);']);    
    eval(['c = ' data_cast '(c);']);
    eval(['p = ' data_cast '(p);']);
    % NOTE(review): the linear indices are cast as well; for 'single' this
    % presumably limits exactly representable indices to 2^24 - confirm
    % for large grids
    eval(['sensor_mask_ind  = ' data_cast '(sensor_mask_ind);']);
    
    % cast variables only used in forward simulation
    if ~time_rev
        eval(['sensor_data = ' data_cast '(sensor_data);']);

    % cast variables used in time reversal to data_cast type    
    else
        
        eval(['p0  = ' data_cast '(p0);']);  
    
        % additional variables used in source mode
        if time_rev == 2
            eval(['time_rev_x_scale = ' data_cast '(time_rev_x_scale);']);  
            eval(['time_rev_y_scale = ' data_cast '(time_rev_y_scale);']);         
            eval(['time_rev_z_scale = ' data_cast '(time_rev_z_scale);']);          
        end
        
        % additional variables used in threshold mode
        if time_rev == 3
            eval(['p_sensor  = ' data_cast '(p_sensor);']);
            eval(['p_sensor_threshold_index  = ' data_cast '(p_sensor_threshold_index );']);
        end
    end
end

% =========================================================================
% LOOP THROUGH TIME STEPS
% =========================================================================

% update command line status
disp(['  precomputation completed in ' scaleTime(toc)]);
disp('  starting time loop...');

% restart timing variable
tic;

% set adaptive boundary condition loop flag; in adaptive mode the field
% update is skipped until the first time step (in loop order) that
% contains sensor data at or above the threshold
if time_rev == 3
    null_boundary_condition = true;
    disp('  skipping time steps with no boundary input...');
else
    null_boundary_condition = false;
end

for t_index = index_start:index_step:index_end
    
    % enforce time reversal boundary condition
    switch time_rev
        case 1          
            % conventional time reversal: overwrite the pressure at the
            % sensor positions (split equally between the three components)
            px(sensor_mask_ind) = p0(:, t_index)/3;
            py(sensor_mask_ind) = p0(:, t_index)/3;
            pz(sensor_mask_ind) = p0(:, t_index)/3;
        case 2
            % source mode: add the scaled sensor data to the existing
            % pressure at the sensor positions
            px(sensor_mask_ind) = px(sensor_mask_ind) + time_rev_x_scale.*p0(:, t_index);
            py(sensor_mask_ind) = py(sensor_mask_ind) + time_rev_y_scale.*p0(:, t_index);
            pz(sensor_mask_ind) = pz(sensor_mask_ind) + time_rev_z_scale.*p0(:, t_index);            
        case 3
            
            % check if any values above the threshold were found
            if p_sensor_threshold_index(t_index)

                % update the boundary condition parameter
                null_boundary_condition = false;

                % place the boundary pressure values into a larger grid
                p_sensor(sensor_mask_ind) = p0(:, t_index);
                p_sensor_index = find(p_sensor ~= 0);
                
                % apply adaptive boundary condition (only at the sensor
                % positions holding non-zero, i.e. above threshold, values)
                px(p_sensor_index) = p_sensor(p_sensor_index)/3;
                py(p_sensor_index) = p_sensor(p_sensor_index)/3;                
                pz(p_sensor_index) = p_sensor(p_sensor_index)/3;        

            end
    end    
    
    % skip the field update in adaptive time reversal mode while no
    % boundary data above the threshold has been encountered
    if ~null_boundary_condition

        % calculate dp/dx, dp/dy and dp/dz on the staggered grids using
        % p = px + py + pz
        p_k = fftn(px + py + pz);
        dpdx = real(ifftn(  ddx_k .* p_k .* x_shift  ));
        dpdy = real(ifftn(  ddy_k .* p_k .* y_shift  ));    
        dpdz = real(ifftn(  ddz_k .* p_k .* z_shift  ));

        % calculate ux, uy and uz for t + dt/2 using dp/dx, dp/dy and
        % dp/dz (the PML operator is applied before and after the update)
        ux_r1 = real(abc_x .* (  abc_x.*ux_r1 - dt./rho_r1.*dpdx  ));
        uy_r2 = real(abc_y .* (  abc_y.*uy_r2 - dt./rho_r2.*dpdy  ));
        uz_r3 = real(abc_z .* (  abc_z.*uz_r3 - dt./rho_r3.*dpdz  ));

        % calculate dux/dx, duy/dy and duz/dz for t + dt/2 (shifted back
        % onto the regular grid)
        duxdx = real(ifftn(  ddx_k .* fftn(ux_r1) .* x_shift_min  ));
        duydy = real(ifftn(  ddy_k .* fftn(uy_r2) .* y_shift_min  ));        
        duzdz = real(ifftn(  ddz_k .* fftn(uz_r3) .* z_shift_min  ));     

        % update pressure values for t + dt
        px = real(abc_x_alt .* (  px - dt.*rho.*c.^2.*duxdx  ));
        py = real(abc_y_alt .* (  py - dt.*rho.*c.^2.*duydy  ));        
        pz = real(abc_z_alt .* (  pz - dt.*rho.*c.^2.*duzdz  ));         
        
        % extract required data (forward mode only): record the total
        % pressure at the sensor positions for this time step
        if ~time_rev
            p = px + py + pz;
            p_array = reshape(p, [], 1, 1);           
            sensor_data(:, t_index) = p_array(sensor_mask_ind);
        end

        % plot data if required
        if plot_sim && rem(t_index, plot_freq) == 0  
            
            % sum pressure components if not already computed
            if time_rev
                p = px + py + pz;
            end            
            
            % update progress bar
            waitbar(t_index/length(t_array));
            drawnow;
            
            % add sensor mask to picture (only modifies the plotting copy
            % p; the stored fields px, py and pz are not changed)
            p(sensor_mask_ind) = plot_scale(2);
            
            % update plot
            planeplot(kgrid, double(p), '', plot_scale, scale, prefix, COLOR_MAP);

        end
    end
end

% =========================================================================
% CLEAN UP
% =========================================================================

% save the final pressure field if in time reversal mode (this is what is
% returned as sensor_data in time reversal mode)
if time_rev
    p = px + py + pz;
    sensor_data = p(z1:z2, x1:x2, y1:y2);
end

% cast pressure variable back to double if required
if ~strcmp(data_cast, 'off')
    sensor_data = double(sensor_data);
end

% reorder the sensor points if binary sensor mask was used for Cartesian
% sensor mask nearest neighbour interpolation
if reorder_data
    
    % update command line status
    disp('  reordering Cartesian measurement data...');
    
    % append the reordering data as an extra column of sensor_data
    new_col_pos = length(sensor_data(1,:)) + 1;
    sensor_data(:, new_col_pos) = reorder_index;

    % reorder sensor_data based on the reorder_index
    sensor_data = sortrows(sensor_data, new_col_pos);

    % remove the reordering data
    sensor_data = sensor_data(:, 1:new_col_pos - 1);
    
end

% clean up used figures
if plot_sim
    close(img);
    close(pbar);
end

% update command line status
disp(['  computation completed in ' scaleTime(toc)]);

function planeplot(kgrid, data, data_title, plot_scale, scale, prefix, color_map)
% Subfunction to produce a plot of a three-dimensional matrix through the
% three central planes (x-z, y-z and x-y), drawn as three subplots of the
% current figure.
%
% INPUTS:
%       kgrid       - grid structure; the fields Nx, Ny, Nz and the
%                     coordinate matrices x, y, z are used
%       data        - three-dimensional matrix to plot, indexed (z, x, y)
%       data_title  - string prepended to the title of the x-z plane plot
%       plot_scale  - [min, max] colour limits passed to imagesc
%       scale       - scaling factor applied to the axis coordinates
%       prefix      - unit prefix string used in the axis label
%       color_map   - colour map applied to the figure

% indices of the central planes; ceil keeps the index an integer when a
% grid dimension is odd (N/2 would be fractional and an invalid index) and
% gives the same index as N/2 when the dimension is even
x_mid = ceil(kgrid.Nx/2);
y_mid = ceil(kgrid.Ny/2);
z_mid = ceil(kgrid.Nz/2);

subplot(2, 2, 1), imagesc(kgrid.x(1,:,1)*scale, kgrid.z(:,1,1)*scale, squeeze(data(:, :, y_mid)), plot_scale);
title([data_title 'x-z plane']);
axis image;
subplot(2, 2, 2), imagesc(kgrid.y(1,1,:)*scale, kgrid.z(:,1,1)*scale, squeeze(data(:, x_mid, :)), plot_scale);
title('y-z plane');
axis image;
xlabel(['(All axes in ' prefix 'm)']);
% the x-y slice is transposed so that x runs along the horizontal axis
subplot(2, 2, 3), imagesc(kgrid.x(1,:,1)*scale, kgrid.y(1,1,:)*scale, squeeze(data(z_mid, :, :)).', plot_scale);
title('x-y plane');
axis image;
colormap(color_map); 
drawnow;